import os
import re
import json
import time
from datetime import datetime
from tqdm import tqdm
from call_gpt import call_gpt
from prompts.creative_writing import standard_prompt, cot_prompt, basm_prompt
from prompts.creative_writing import referee_prompt
import argparse

# Command-line interface. Arguments are parsed at module level, and the
# resulting values are read as globals by generate() and evaluate().
parser = argparse.ArgumentParser(description='Run creative writing task')

parser.add_argument('--dataset_dir', type=str, default='../dataset', help='Directory of the dataset')
parser.add_argument('--model', type=str, default='gpt-3.5-turbo', help='Model to use')
parser.add_argument('--method', type=str, default='standard', choices=['standard', 'cot', 'basm'], help='Method to use')
# Every other flag uses underscores; accept `--max_tokens` as an alias so the
# spelling is consistent, while keeping the original `--max-tokens` working.
parser.add_argument('--max-tokens', '--max_tokens', dest='max_tokens', type=int, default=1000, help='Max tokens of generated story')
parser.add_argument('--log_dir', type=str, default='log', help='Directory for logs')
# NOTE(review): the help text mentions a gpt4 referee, but compare_better_story
# hard-codes 'gpt-3.5-turbo' as the referee model -- confirm which is intended.
parser.add_argument('--phase', type=str, default='generation', choices=['generation', 'evaluation'], help='If set to evaluation, the model will run the automatic evaluation program using a gpt4 referee')
parser.add_argument('--evaluate_dirs', nargs='+', help='A list of evaluate result directories', type=str, default=[])

args = parser.parse_args()
dataset_dir = args.dataset_dir
task = 'creative_writing'  # sub-directory name under both the dataset and log roots
model = args.model
method = args.method
log_dir = args.log_dir
max_tokens = args.max_tokens
phase = args.phase

# `choices` already restricts --method to lower-case values; kept as a safeguard.
method = method.lower()

# All runs of this task are stored under <log_dir>/creative_writing.
log_dir_base = os.path.join(log_dir, task)


def generate():
    """Generate one story per dataset item and save each to its own file.

    Reads ``<dataset_dir>/creative_writing/task.json``, formats the prompt
    template selected by the module-level ``method`` with each item's
    ``'input'`` field, sends it to ``call_gpt`` and writes the reply into a
    fresh ``<log_dir_base>/<model>_<method>_<timestamp>`` directory. The file
    name is the first 64 characters of the item's input.

    Raises:
        KeyError: if ``method`` is not one of 'standard', 'cot' or 'basm'.
    """
    current_time_str = datetime.now().strftime('%Y%m%d_%H%M%S')

    # One directory per run; makedirs also creates log_dir_base when missing.
    run_log_dir = os.path.join(log_dir_base, f'{model}_{method}_{current_time_str}')
    os.makedirs(run_log_dir, exist_ok=True)

    metadata_path = os.path.join(dataset_dir, task, 'task.json')
    with open(metadata_path, 'r', encoding='utf8') as f:
        metadata = json.load(f)

    # The template depends only on the method, so pick it once up front.
    # 'basm' must use its own template (it previously reused the CoT one).
    prompts_by_method = {
        'standard': standard_prompt,
        'cot': cot_prompt,
        'basm': basm_prompt,
    }
    prompt = prompts_by_method[method]
    temperature = 0.7

    for item in tqdm(metadata):
        result = call_gpt(
            prompt.format(story_prompt=item['input']),
            model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # The input prefix doubles as the file name; a path separator inside
        # it would point at a non-existent sub-directory, so neutralise it.
        filename = item['input'][:64].replace('/', '_').replace(os.sep, '_')
        with open(os.path.join(run_log_dir, filename), 'w', encoding='utf8') as f:
            f.write(result)

        # Pause between requests (presumably to stay under API rate limits).
        time.sleep(5)


def _referee_verdict(first_story, second_story, referee_model):
    """Ask the referee which story is better; return 1, 2, or -1 if unusable.

    The reply is limited to a single token and parsed as an integer. Anything
    that is not exactly 1 or 2 (non-numeric text, a missing reply, or another
    number such as 0 or 3) is reported as -1.
    """
    result = call_gpt(
        referee_prompt.format(first_story=first_story, second_story=second_story),
        model=referee_model,
        temperature=0,
        max_tokens=1,
    )
    try:
        verdict = int(result)
    except (TypeError, ValueError):
        return -1
    return verdict if verdict in (1, 2) else -1


def compare_better_story(story1, story2):
    """Decide which of two stories the referee model prefers.

    The referee is queried twice, the second time with the stories swapped,
    and the two verdicts are combined into one outcome.

    Args:
        story1: Text of the first story.
        story2: Text of the second story.

    Returns:
        1 if ``story1`` wins, 2 if ``story2`` wins, 0 if the two rounds
        contradict each other (counted as a tie), and -1 if neither round
        produced a usable verdict.
    """
    # NOTE(review): the --phase help text mentions a gpt4 referee, but this
    # model is what the code actually uses -- confirm which is intended.
    referee_model = 'gpt-3.5-turbo'

    better_id_1 = _referee_verdict(story1, story2, referee_model)
    print('=' * 20)
    print(better_id_1)

    # Second round with the order swapped: here verdict 1 refers to story2.
    better_id_2 = _referee_verdict(story2, story1, referee_model)
    print(better_id_2)

    # Keys are (round-1 verdict, round-2 verdict); values are the final outcome.
    result_map = {
        (-1, -1): -1,  # no usable verdict at all
        (1, -1): 1,    # only round 1 is usable
        (-1, 1): 2,    # only round 2 is usable; its "first" is story2
        (2, -1): 2,
        (-1, 2): 1,
        (1, 1): 0,     # referee picked the first slot both times -> tie
        (2, 2): 0,     # referee picked the second slot both times -> tie
        (1, 2): 1,     # both rounds agree on story1
        (2, 1): 2,     # both rounds agree on story2
    }
    return result_map[(better_id_1, better_id_2)]


def evaluate():
    """Compare the stories of every pair of evaluated runs and print the tallies.

    Each entry of ``args.evaluate_dirs`` names a directory under
    ``log_dir_base``; the method name is the second ``'_'``-separated field
    of that directory name. For every unordered pair of methods, the stories
    whose file names exist in both directories are judged with
    ``compare_better_story`` and the outcomes are counted. The nested result
    dictionary is printed at the end.
    """
    runs = {}
    method_names = []
    for run_dir_name in args.evaluate_dirs:
        method_name = run_dir_name.split('_')[1]
        story_dir = os.path.join(log_dir_base, run_dir_name)
        runs[method_name] = {
            'method': method_name,
            'story_dir': story_dir,
            'story_filenames': os.listdir(story_dir),
        }
        method_names.append(method_name)
    assert len(set(method_names)) == len(method_names), 'Method must be unique, please check!'

    # Maps a compare_better_story outcome to the counter it increments.
    counter_for_outcome = {
        -1: 'num_invalid',
        0: 'num_tie',
        1: 'num_win1',
        2: 'num_win2',
    }

    evaluate_result = {}
    for i, method1 in enumerate(method_names):
        evaluate_result[method1] = {}
        for j, method2 in enumerate(tqdm(method_names)):
            # Only visit each unordered pair once (method2 after method1).
            if j <= i:
                continue

            common_filenames = set(runs[method1]['story_filenames']) & set(runs[method2]['story_filenames'])
            tally = {counter: 0 for counter in counter_for_outcome.values()}
            for filename in common_filenames:
                print(f'method1 {method1} vs method2 {method2} on instance {filename}')
                path1 = os.path.join(runs[method1]['story_dir'], filename)
                path2 = os.path.join(runs[method2]['story_dir'], filename)
                with open(path1, 'r', encoding='utf8') as story_file:
                    story1 = story_file.read()
                with open(path2, 'r', encoding='utf8') as story_file:
                    story2 = story_file.read()

                counter = counter_for_outcome.get(compare_better_story(story1, story2))
                if counter is not None:
                    tally[counter] += 1

            evaluate_result[method1][method2] = tally

    print(evaluate_result)


if __name__ == '__main__':
    # Run the routine matching the requested --phase; any other value is a no-op.
    phase_handlers = {'generation': generate, 'evaluation': evaluate}
    run_phase = phase_handlers.get(phase)
    if run_phase is not None:
        run_phase()
